/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

#include "../../../common/data_utils.h"
#ifndef ASCENDC_CPU_DEBUG
#include "acl/acl.h"
extern void layernorm_grad_custom_do(uint32_t blockDim, void *l2ctrl, void *stream, uint8_t *inputXGm, uint8_t *inputDyGm,
    uint8_t *inputVarianceGm, uint8_t *inputMeanGm, uint8_t *inputGammaGm, uint8_t *ouputPdXGm,
    uint8_t *resForGammaGm, uint8_t *workspace, uint8_t *tiling);
#else
#include "tikicpulib.h"
extern "C" __global__ __aicore__ void layernorm_grad_custom(GM_ADDR inputXGm, GM_ADDR inputDyGm, GM_ADDR inputVarianceGm,
    GM_ADDR inputMeanGm, GM_ADDR inputGammaGm, GM_ADDR outputPdXGm, GM_ADDR resForGammaGm,
    GM_ADDR workspace, GM_ADDR tiling);
#endif

// Problem shape used by main: inputX, inputDy and both outputs hold BLENGTH * SLENGTH * HLENGTH floats,
// inputMean / inputVariance hold BLENGTH * SLENGTH floats, and inputGamma holds HLENGTH floats.
constexpr uint32_t BLENGTH = 2;
constexpr uint32_t SLENGTH = 32;
constexpr uint32_t HLENGTH = 16;

// Launch block count, passed as blockDim to the kernel launch in both build modes.
constexpr uint32_t BLOCK_DIM = 1;

// Length of the tiling data in uint32_t words (main sizes the tiling buffer as TILINGDATA_SIZE * sizeof(uint32_t)).
constexpr uint32_t TILINGDATA_SIZE = 31;

// Kernel workspace size in bytes (16 MiB).
constexpr uint32_t WORKSPACE_SIZE = 16U * 1024U * 1024U;

// Builds the tiling buffer for a (bLength, sLength, hLength) problem; defined elsewhere.
// main releases the returned buffer with free().
extern uint8_t *GenerateTiling(uint32_t bLength, uint32_t sLength, uint32_t hLength);

static bool CompareResult(const void *outputData, int64_t outSize, std::string goldenName)
{
    void *goldenData;
#ifdef ASCENDC_CPU_DEBUG
    goldenData = (uint8_t *)AscendC::GmAlloc(outSize);
#else
    CHECK_ACL(aclrtMallocHost((void **)(&goldenData), outSize));
#endif
    size_t goldenSize = outSize;
    bool ret = ReadFile("../output/golden_" + goldenName + ".bin", goldenSize, goldenData, goldenSize);
    if (ret) {
        printf("ReadFile golden_%s.bin success!\n", goldenName.c_str());
    } else {
#ifdef ASCENDC_CPU_DEBUG
        AscendC::GmFree((void *)goldenData);
#else
        CHECK_ACL(aclrtFreeHost(goldenData));
#endif
        return false;
    }
    constexpr float EPS = 1e-4;
    int64_t wrongNum = 0;

    for (int i = 0; i < outSize / sizeof(float); i++) {
        float a = (reinterpret_cast<const float *>(outputData))[i];
        float b = (reinterpret_cast<const float *>(goldenData))[i];
        float ae = std::abs(a - b);
        float re = ae / abs(b);
        if (ae > EPS && re > EPS) {
            printf("CompareResult golden_%s.bin failed output is %lf, golden is %lf\n", goldenName.c_str(), a, b);
            wrongNum++;
        }
    }
#ifdef ASCENDC_CPU_DEBUG
    AscendC::GmFree((void *)goldenData);
#else
    CHECK_ACL(aclrtFreeHost(goldenData));
#endif
    if (wrongNum != 0) {
        return false;
    } else {
        printf("CompareResult golden_%s.bin success!\n", goldenName.c_str());
        return true;
    }
}

int32_t main(int32_t argc, char *argv[])
{
    uint32_t blockDim = BLOCK_DIM;
    size_t workspaceSize = WORKSPACE_SIZE;
    size_t inputSize_inputX = BLENGTH * SLENGTH * HLENGTH * sizeof(float);
    size_t inputSize_inputDy = BLENGTH * SLENGTH * HLENGTH * sizeof(float);
    size_t inputSize_inputMean = BLENGTH * SLENGTH * sizeof(float);
    size_t inputSize_inputVariance = BLENGTH * SLENGTH * sizeof(float);
    size_t inputSize_inputGamma = HLENGTH * sizeof(float);
    size_t outputSize_outputPdX = BLENGTH * SLENGTH * HLENGTH * sizeof(float);
    size_t outputSize_resForGamma = BLENGTH * SLENGTH * HLENGTH * sizeof(float);
    size_t tilingFileSize = TILINGDATA_SIZE * sizeof(uint32_t);
    uint8_t *tilingBuf = GenerateTiling(BLENGTH, SLENGTH, HLENGTH);

#ifdef ASCENDC_CPU_DEBUG
    uint8_t *inputX = (uint8_t *)AscendC::GmAlloc(inputSize_inputX);
    uint8_t *inputDy = (uint8_t *)AscendC::GmAlloc(inputSize_inputDy);
    uint8_t *inputMean = (uint8_t *)AscendC::GmAlloc(inputSize_inputMean);
    uint8_t *inputVariance = (uint8_t *)AscendC::GmAlloc(inputSize_inputVariance);
    uint8_t *inputGamma = (uint8_t *)AscendC::GmAlloc(inputSize_inputGamma);
    uint8_t *outputPdX = (uint8_t *)AscendC::GmAlloc(outputSize_outputPdX);
    uint8_t *resForGamma = (uint8_t *)AscendC::GmAlloc(outputSize_resForGamma);
    uint8_t *workspace = (uint8_t *)AscendC::GmAlloc(workspaceSize);
    uint8_t *tiling = (uint8_t *)AscendC::GmAlloc(tilingFileSize);

    ReadFile("../input/input_inputX.bin", inputSize_inputX, inputX, inputSize_inputX);
    ReadFile("../input/input_inputDy.bin", inputSize_inputDy, inputDy, inputSize_inputDy);
    ReadFile("../input/input_inputMean.bin", inputSize_inputMean, inputMean, inputSize_inputMean);
    ReadFile("../input/input_inputVariance.bin", inputSize_inputVariance, inputVariance, inputSize_inputVariance);
    ReadFile("../input/input_inputGamma.bin", inputSize_inputGamma, inputGamma, inputSize_inputGamma);

    memcpy_s(tiling, tilingFileSize, tilingBuf, tilingFileSize);

    AscendC::SetKernelMode(KernelMode::AIV_MODE);
    ICPU_RUN_KF(layernorm_grad_custom, blockDim, inputX, inputDy, inputVariance, inputMean, inputGamma,
        outputPdX, resForGamma, workspace, tiling); // use this macro for cpu debug

    WriteFile("../output/output_outputPdX.bin", outputPdX, outputSize_outputPdX);
    WriteFile("../output/output_resForGamma.bin", resForGamma, outputSize_resForGamma);
    bool goldenResult = true;
    goldenResult &= CompareResult(outputPdX, outputSize_outputPdX, "outputPdX");
    goldenResult &= CompareResult(resForGamma, outputSize_resForGamma, "resForGamma");
    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }

    AscendC::GmFree((void *)inputX);
    AscendC::GmFree((void *)inputDy);
    AscendC::GmFree((void *)inputMean);
    AscendC::GmFree((void *)inputVariance);
    AscendC::GmFree((void *)inputGamma);
    AscendC::GmFree((void *)outputPdX);
    AscendC::GmFree((void *)resForGamma);
    AscendC::GmFree((void *)workspace);
    AscendC::GmFree((void *)tiling);
#else
    CHECK_ACL(aclInit(nullptr));
    int32_t deviceId = 0;
    CHECK_ACL(aclrtSetDevice(deviceId));
    aclrtStream stream = nullptr;
    CHECK_ACL(aclrtCreateStream(&stream));

    uint8_t *inputXHost, *inputDyHost, *inputMeanHost, *inputVarianceHost, *inputGammaHost,
        *outputPdXHost, *resForGammaHost, *workspaceHost, *tilingHost;
    uint8_t *inputXDevice, *inputDyDevice, *inputMeanDevice, *inputVarianceDevice, *inputGammaDevice,
        *outputPdXDevice, *resForGammaDevice, *workspaceDevice, *tilingDevice;

    CHECK_ACL(aclrtMallocHost((void **)(&inputXHost), inputSize_inputX));
    CHECK_ACL(aclrtMallocHost((void **)(&inputDyHost), inputSize_inputDy));
    CHECK_ACL(aclrtMallocHost((void **)(&inputMeanHost), inputSize_inputMean));
    CHECK_ACL(aclrtMallocHost((void **)(&inputVarianceHost), inputSize_inputVariance));
    CHECK_ACL(aclrtMallocHost((void **)(&inputGammaHost), inputSize_inputGamma));
    CHECK_ACL(aclrtMallocHost((void **)(&outputPdXHost), outputSize_outputPdX));
    CHECK_ACL(aclrtMallocHost((void **)(&resForGammaHost), outputSize_resForGamma));
    CHECK_ACL(aclrtMallocHost((void **)(&workspaceHost), workspaceSize));
    CHECK_ACL(aclrtMallocHost((void **)(&tilingHost), tilingFileSize));
    CHECK_ACL(aclrtMalloc((void **)&inputXDevice, inputSize_inputX, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&inputDyDevice, inputSize_inputDy, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&inputMeanDevice, inputSize_inputMean, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&inputVarianceDevice, inputSize_inputVariance, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&inputGammaDevice, inputSize_inputGamma, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&outputPdXDevice, outputSize_outputPdX, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&resForGammaDevice, outputSize_resForGamma, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&workspaceDevice, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&tilingDevice, tilingFileSize, ACL_MEM_MALLOC_HUGE_FIRST));

    ReadFile("../input/input_inputX.bin", inputSize_inputX, inputXHost, inputSize_inputX);
    ReadFile("../input/input_inputDy.bin", inputSize_inputDy, inputDyHost, inputSize_inputDy);
    ReadFile("../input/input_inputMean.bin", inputSize_inputMean, inputMeanHost, inputSize_inputMean);
    ReadFile("../input/input_inputVariance.bin", inputSize_inputVariance, inputVarianceHost, inputSize_inputVariance);
    ReadFile("../input/input_inputGamma.bin", inputSize_inputGamma, inputGammaHost, inputSize_inputGamma);

    CHECK_ACL(aclrtMemcpy(workspaceDevice, workspaceSize, workspaceHost, workspaceSize, ACL_MEMCPY_HOST_TO_DEVICE));

    CHECK_ACL(aclrtMemcpy(tilingDevice, tilingFileSize, tilingBuf, tilingFileSize,
        ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inputXDevice, inputSize_inputX, inputXHost, inputSize_inputX,
        ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inputDyDevice, inputSize_inputDy, inputDyHost, inputSize_inputDy,
        ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inputMeanDevice, inputSize_inputMean, inputMeanHost, inputSize_inputMean,
        ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inputVarianceDevice, inputSize_inputVariance, inputVarianceHost, inputSize_inputVariance,
        ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inputGammaDevice, inputSize_inputGamma, inputGammaHost, inputSize_inputGamma,
        ACL_MEMCPY_HOST_TO_DEVICE));

    layernorm_grad_custom_do(blockDim, nullptr, stream, inputXDevice, inputDyDevice, inputVarianceDevice,
        inputMeanDevice, inputGammaDevice, outputPdXDevice, resForGammaDevice, workspaceDevice, tilingDevice);

    CHECK_ACL(aclrtSynchronizeStream(stream));

    CHECK_ACL(aclrtMemcpy(outputPdXHost, outputSize_outputPdX, outputPdXDevice, outputSize_outputPdX,
        ACL_MEMCPY_DEVICE_TO_HOST));
    CHECK_ACL(aclrtMemcpy(resForGammaHost, outputSize_resForGamma, resForGammaDevice, outputSize_resForGamma,
        ACL_MEMCPY_DEVICE_TO_HOST));

    WriteFile("../output/output_outputPdX.bin", outputPdXHost, outputSize_outputPdX);
    WriteFile("../output/output_resForGamma.bin", resForGammaHost, outputSize_resForGamma);
    bool goldenResult = true;
    goldenResult &= CompareResult(outputPdXHost, outputSize_outputPdX, "outputPdX");
    goldenResult &= CompareResult(resForGammaHost, outputSize_resForGamma, "resForGamma");
    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }

    CHECK_ACL(aclrtFree(inputXDevice));
    CHECK_ACL(aclrtFree(inputDyDevice));
    CHECK_ACL(aclrtFree(inputMeanDevice));
    CHECK_ACL(aclrtFree(inputVarianceDevice));
    CHECK_ACL(aclrtFree(inputGammaDevice));
    CHECK_ACL(aclrtFree(outputPdXDevice));
    CHECK_ACL(aclrtFree(resForGammaDevice));
    CHECK_ACL(aclrtFree(workspaceDevice));
    CHECK_ACL(aclrtFree(tilingDevice));

    CHECK_ACL(aclrtFreeHost(inputXHost));
    CHECK_ACL(aclrtFreeHost(inputDyHost));
    CHECK_ACL(aclrtFreeHost(inputMeanHost));
    CHECK_ACL(aclrtFreeHost(inputVarianceHost));
    CHECK_ACL(aclrtFreeHost(inputGammaHost));
    CHECK_ACL(aclrtFreeHost(outputPdXHost));
    CHECK_ACL(aclrtFreeHost(resForGammaHost));
    CHECK_ACL(aclrtFreeHost(workspaceHost));
    CHECK_ACL(aclrtFreeHost(tilingHost));

    CHECK_ACL(aclrtDestroyStream(stream));
    CHECK_ACL(aclrtResetDevice(deviceId));
    CHECK_ACL(aclFinalize());
#endif
    free(tilingBuf);
    return 0;
}
